import torch
import torch.nn as nn
import torch.nn.functional as F

from loss import loss


def loss_fn(alpha, beta, l, s1, s2, s3, emb, y, emb_true):
    """Combine three weighted ``loss`` terms with a weighted cross-entropy term.

    Each of ``s1``, ``s2`` and ``s3`` is scored against the same target ``y``
    with the project's ``loss`` helper. The three terms are mixed with weights
    ``1 - alpha - beta``, ``alpha`` and ``beta`` respectively. A
    ``F.cross_entropy(emb, emb_true)`` term scaled by ``l`` is added on top.

    Args:
        alpha: Weight on ``loss(s2, y)``.
        beta: Weight on ``loss(s3, y)``. ``loss(s1, y)`` receives the
            remaining weight ``1 - alpha - beta``.
        l: Weight on the cross-entropy term.
        s1: First input passed to ``loss`` together with ``y``.
        s2: Second input passed to ``loss`` together with ``y``.
        s3: Third input passed to ``loss`` together with ``y``.
        emb: Input to ``F.cross_entropy``. It is presumably unnormalized
            logits; confirm against the caller.
        emb_true: Target for ``F.cross_entropy``. PyTorch accepts either
            class indices or class probabilities here; confirm which the
            caller passes.
        y: Target passed to ``loss`` for all three weighted terms.

    Returns:
        The sum of the three weighted ``loss`` terms and the weighted
        cross-entropy term.
    """
    # NOTE(review): nothing enforces alpha + beta <= 1. If the sum exceeds 1,
    # the weight on the s1 term becomes negative. Confirm callers guarantee it.
    l1 = (1 - alpha - beta) * loss(s1, y)
    l2 = alpha * loss(s2, y)
    l3 = beta * loss(s3, y)
    # Auxiliary term on emb, scaled independently of the alpha/beta mix.
    ce = l * F.cross_entropy(emb, emb_true)
    return l1 + l2 + l3 + ce
